
Coalesce B-scale reads for dynamic-dim MXFP4 preshuffle kernels #1058

Open
Hardcode84 wants to merge 15 commits into iree-org:main from Hardcode84:coalesce_b_scale_nonmultiple_blkN_2

Conversation

@Hardcode84
Contributor

Problem

When M, N, K are dynamic, the read coalescer in partition_strided_operators.py fails to merge the 16 per-thread B-scale byte reads into a single contiguous vector<16xi8> load. Instead it produces fragmented loads of mixed widths ({2, 16, 8, 4}) stitched together with vector.from_elements.

Root cause: the pairwise merge verifies candidate pairs by computing per-dim offset diffs at multiple numeric probe points. For B-scale buffers with shape [N, K/2], the 2D decomposition row = offset floordiv K/2, col = offset mod K/2 gives inconsistent diffs when probe values (e.g. K=137 → K/2=68) don't respect the kernel's divisibility constraints (K % 256 == 0). The verification then rejects perfectly valid merge candidates.

With static dims this doesn't happen because K is concrete and the decomposition is trivially consistent.
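The failure mode can be reproduced in isolation with sympy. The sketch below is illustrative, not the project code; it uses K=136 (rather than the K=137 mentioned above) so that K/2 stays integral, which is enough to show the inconsistency: a non-conforming probe sees a row boundary between adjacent bytes that cannot exist when K % 256 == 0.

```python
# Illustrative sketch (not the project code) of the failure mode: probing
# the 2D decomposition row = offset // (K/2), col = offset mod (K/2) with
# a K value that violates K % 256 == 0 yields per-dim diffs that disagree
# with a conforming probe, so the merge candidate is rejected.
import sympy as sp

K = sp.Symbol("K", positive=True, integer=True)
off = sp.Symbol("off", nonnegative=True, integer=True)

row = sp.floor(off / (K / 2))
col = sp.Mod(off, K / 2)

def dim_diffs(off_a, off_b, k_val):
    """Per-dim offset diffs between two flat offsets at a concrete K."""
    a = {off: off_a, K: k_val}
    b = {off: off_b, K: k_val}
    return (int(row.subs(b) - row.subs(a)), int(col.subs(b) - col.subs(a)))

# Adjacent bytes at flat offsets 67 and 68:
print(dim_diffs(67, 68, 256))  # (0, 1): same row, next column
print(dim_diffs(67, 68, 136))  # (1, -67): a row boundary at K/2 = 68
                               # that cannot exist when K % 256 == 0
```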

Solution

  1. Divisibility substitutions before probing: plumb get_divisibility_subs(constraints) into the merge pipeline. Before numeric probing, apply forward subs like K → 256*K' so that floordiv/Mod evaluate consistently across all probe sets. This is the key fix — with it, B-scale reads coalesce into 8 clean vector<16xi8> loads identical to the static-dims case.

  2. Numeric probing for pairwise merge: replace the old symbolic diff approach (which exploded on complex preshuffle index expressions) with concrete numeric evaluation using three linear generators that avoid pathological values.

  3. Re-merging across mask levels: allow reads that already carry precomputed masks (from prior merge rounds) to participate in further merging. The mask is extended to cover the wider result and the precomputed condition is preserved.

  4. Bounds pre-flattening: pre-compute each read's bounds check as a flat sympy boolean before merging, so the coalescer can reason about mask compatibility without re-deriving bounds at each merge step.
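As a rough sketch of step 1: the helper name and constraint format below are assumptions for illustration, not the project's actual `get_divisibility_subs` API, but they show how a divisibility constraint becomes a forward substitution that makes floordiv/Mod probe-stable.

```python
# Hypothetical sketch: turn divisibility constraints into forward subs
# K -> 256*K_q so that floordiv/Mod evaluate consistently at any probe.
import sympy as sp

def get_divisibility_subs(divisibility):
    """divisibility: {symbol: modulus}, e.g. {K: 256} for K % 256 == 0.
    Returns forward subs mapping each symbol to modulus * fresh symbol."""
    subs = {}
    for sym, modulus in divisibility.items():
        fresh = sp.Symbol(sym.name + "_q", positive=True, integer=True)
        subs[sym] = modulus * fresh
    return subs

K = sp.Symbol("K", positive=True, integer=True)
off = sp.Symbol("off", nonnegative=True, integer=True)
fwd = get_divisibility_subs({K: 256})

row = sp.floor(off / (K / 2)).xreplace(fwd)  # floor(off / (128*K_q))
K_q = sp.Symbol("K_q", positive=True, integer=True)
# After the substitution, adjacent offsets 67/68 stay in the same row
# for every probe value of K_q, so the per-dim diffs are consistent:
for probe in (1, 2, 5):
    assert row.subs({off: 68, K_q: probe}) == row.subs({off: 67, K_q: probe})
```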

xintin and others added 15 commits March 5, 2026 21:02
Signed-off-by: xintin <gaurav.verma@amd.com>
…explosion

The old _pairwise_merge used O(n²) symbolic diff resolution via
sympy.lambdify, which hangs on huge preshuffle index expressions
(postorder_traversal of the diff tree never completes). Replace with
xreplace-based numeric evaluation of each offset independently, dict
lookup for O(1) candidate matching, and verification across multiple
probe value sets. Fixes dynamic preshuffle MXFP4 GEMM compilation
hanging in merge_contiguous_reads (128 reads now merge in ~1s).

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: Ivan Butygin <ivan.butygin@gmail.com>
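The probing-plus-dict-lookup pattern this commit describes can be sketched as follows. The `pairwise_candidates` helper and the probe values are simplified stand-ins for the real structures, not the project's API: each offset is evaluated independently via xreplace, candidates are matched in O(1) with a dict, and pairings are verified at the remaining probes.

```python
# Sketch of xreplace-based numeric probing: evaluate each read's offset
# at concrete probe points and match merge candidates by O(1) dict
# lookup instead of O(n^2) symbolic diffing.
import sympy as sp

t, K = sp.symbols("t K", integer=True, nonnegative=True)

def numeric_offsets(offset_exprs, probe):
    """Evaluate symbolic offsets to ints at one probe assignment."""
    return [int(e.xreplace(probe)) for e in offset_exprs]

def pairwise_candidates(offset_exprs, ept, probes):
    """Return (i, j) pairs whose offsets differ by ept at every probe."""
    base = numeric_offsets(offset_exprs, probes[0])
    offset_map = {v: i for i, v in enumerate(base)}  # O(1) matching
    pairs = []
    for i, v in enumerate(base):
        j = offset_map.get(v + ept)
        if j is None:
            continue
        # Verify the pairing is stable across the remaining probe sets.
        diff = offset_exprs[j] - offset_exprs[i]
        if all(int(diff.xreplace(p)) == ept for p in probes[1:]):
            pairs.append((i, j))
    return pairs

offs = [t * 4, t * 4 + 2, K + t, K + t + 2]
probes = [{t: 3, K: 1024}, {t: 7, K: 2048}]
print(pairwise_candidates(offs, 2, probes))  # [(0, 1), (2, 3)]
```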
Previously, _group_reads_by_memory skipped reads with
precomputed_mask_expr, preventing merged ept=2 reads from being
further merged to ept=4/8/16. Fix by removing the skip and remapping
the sub-read's iota symbol ($IOTA{old_size} -> $IOTA{wide_ept} -
offset) when composing masks in _build_wide_mask_expr.

Result: dynamic preshuffle MXFP4 b-tensor loads go from 332
vector<2xi8> + 84 vector<16xi8> to 8 vector<2xi8> + 120
vector<16xi8>.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: Ivan Butygin <ivan.butygin@gmail.com>
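The iota remapping described above can be sketched like this. The `$IOTA{n}` naming follows the commit message; the `remap_submask` helper is illustrative, not the project's `_build_wide_mask_expr`.

```python
# Sketch: when a merged ept=2 read (mask over $IOTA2) joins a wider
# ept=16 read at element offset `elem_offset`, its mask is rewritten in
# terms of the wide read's $IOTA16 lane symbol.
import sympy as sp

def remap_submask(mask, old_size, wide_ept, elem_offset):
    """Rewrite a sub-read's mask from its own $IOTA{old_size} symbol to
    the wide read's $IOTA{wide_ept}, shifted by the sub-read's element
    offset inside the wide result."""
    old = sp.Symbol(f"$IOTA{old_size}")
    new = sp.Symbol(f"$IOTA{wide_ept}")
    return mask.xreplace({old: new - elem_offset})

iota2, base, N = sp.Symbol("$IOTA2"), sp.Symbol("base"), sp.Symbol("N")
narrow = iota2 + base < N               # bounds check of an ept=2 sub-read
wide = remap_submask(narrow, 2, 16, 4)  # sub-read starts at wide lane 4
# wide is now ($IOTA16 - 4) + base < N: lane 4 of the wide read maps
# back to lane 0 of the sub-read.
```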
…scing

The pairwise merge uses numeric probing to verify that adjacent reads
have consistent per-dim offset diffs across multiple probe points.
With symbolic K, the 2D decomposition (row = offset floordiv K/2,
col = offset mod K/2) gives inconsistent diffs when probe values
don't respect divisibility constraints — e.g. at K=137, K/2=68,
adjacent bytes straddle a row boundary that doesn't exist at K=256.

Fix by applying divisibility forward subs (K -> 256*K') before
probing, so floordiv/Mod evaluate consistently across all probes.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: Ivan Butygin <ivan.butygin@gmail.com>
Tests use actual B-scale preshuffle index expressions (row = floor(offset
/ (K/2)), col = offset mod (K/2)) to verify that:
- Flat offset diffs are always correct regardless of probe values.
- Per-dim diffs are inconsistent without divisibility subs (the bug).
- Per-dim diffs become consistent after K -> 256*K' substitution.
- _find_merge_dim_from_diffs correctly identifies the merge dimension.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: Ivan Butygin <ivan.butygin@gmail.com>
Verifies that divisibility substitutions (K % 256) enable the read
coalescer to produce clean vector<16xi8> B-scale and vector<4xi8>
A-scale loads from fat_raw_buffer, with no vector.from_elements
fragmentation, when M, N, K are dynamic.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: Ivan Butygin <ivan.butygin@gmail.com>
# CHECK: amdgpu.fat_raw_buffer_cast

# B-scale reads are clean vector<16xi8> from fat_raw_buffer — no
# fragmentation into mixed-width loads glued by from_elements.
Collaborator

nice!

return groups


def _resolve_symbolic_diff(raw_diff, has_complex_mapping, expected_vals=None):
Collaborator

is this dead code?

elems_per_reg = 32 // elem_bits

reg_offset = offset_val // elems_per_reg
reg_count = max(1, (size_val * elem_bits + 31) // 32)
Collaborator

we have a ceildiv function, if that is what this is doing?
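For context, the expression the reviewer flags is the standard add-then-floor ceiling-division idiom; a generic `ceildiv` for positive integers (sketched here, not necessarily the project's helper) computes the same value:

```python
# (size_val * elem_bits + 31) // 32 is ceiling division by 32 in
# disguise; the negation trick below is an equivalent generic ceildiv.
def ceildiv(a, b):
    return -(-a // b)

for elem_bits in (8, 16, 32):
    for size_val in range(1, 17):
        assert (size_val * elem_bits + 31) // 32 == ceildiv(size_val * elem_bits, 32)
```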


def _pairwise_merge(
read_infos, ept, symbolic_dims, symbolic_shape, hw_constraint, divisibility_fwd=None
):
Collaborator

There is a lot of overlap between this and multiway_merge. Can you refactor these 2 functions? In theory they could both use a common ProbeEvaluator + ReadInfo class, where ProbeEvaluator implements verify_diff, offset_map, etc.

hw_constraint,
divisibility_fwd=None,
):
"""Coalesce unmerged ept==1 reads whose flat offsets fall in an aligned window.
Collaborator

The probe-based approach computes num_offs once (O(n)), but then for each anchor, iterates over all probes:

for anchor_idx in range(len(unmerged_infos)):
    ...
    for probe_idx in range(len(unmerged_infos)):

With the same dict-lookup pattern used for _pairwise_merge, the inner loop could check offset_map[target] for each window position instead of scanning all probes.

consistent = True
for ep in extra_probes:
try:
va = _eval_expr(resolved_offs[anchor_idx], ep)
Collaborator

Also could do with more expressive variable names.

break
if not consistent:
continue
_, custom_p, node_p = (
Collaborator

Why not just unpack unmerged_infos directly?
